{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import PreTrainedTokenizerFast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save/tokenizer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 训练tokenizer（可选）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tokenizers\n",
    "from tokenizers import Tokenizer, decoders\n",
    "from tokenizers.models import BPE\n",
    "from tokenizers.trainers import BpeTrainer\n",
    "from tokenizers.pre_tokenizers import Punctuation, Digits, Metaspace, ByteLevel\n",
    "from tokenizers.normalizers import NFKC \n",
    "from rich import progress"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 定义tokenizer训练语料来源"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cropus_file =  './data/wiki.simple.txt'\n",
    "tokenizer_save_path = './model_save/hf_bpe_tokenizer.josn'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. 训练tokenizer的函数\n",
    "`get_training_corpus`函数将多个短拒绝拼接成长度大于`chunk_len=2048`句子，每次迭代返回`buffer_size=1000`个这样的长句子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_my_huggingface_wiki_tokenizer(max_train_line: int=None, token_type: str='char') -> None:\n",
    "    '''\n",
    "    训练tokenizer with huggingface，至少需要32G内存，运行大概需要半个小时。\n",
    "    '''\n",
    "\n",
    "    # if not exists(tokenizer_save_path): mkdir(tokenizer_save_path)\n",
    "\n",
    "    def get_training_corpus(buffer_size: int=1000, chunk_len: int=2048) -> list:\n",
    "        '''\n",
    "        一个文本块大小2048\n",
    "        '''\n",
    "        line_cnt = 0\n",
    "        buffer = []\n",
    "        with open(cropus_file, 'r', encoding='utf-8') as f_read:\n",
    "            cur_chunk_txt, txt_len = [], 0\n",
    "            for line in f_read:\n",
    "\n",
    "                cur_chunk_txt.append(line)\n",
    "                txt_len += len(line)\n",
    "                line_cnt += 1\n",
    "\n",
    "                if txt_len >= chunk_len:\n",
    "                    buffer.append(\n",
    "                        ''.join(cur_chunk_txt)\n",
    "                    )\n",
    "                    cur_chunk_txt, txt_len = [], 0\n",
    "                \n",
    "                if len(buffer) >= buffer_size:\n",
    "                    yield buffer\n",
    "                    buffer = []\n",
    "\n",
    "                if isinstance(max_train_line, int) and line_cnt > max_train_line: break\n",
    "                \n",
    "            # yield last\n",
    "            if len(buffer) > 0: yield buffer        \n",
    "\n",
    "    special_tokens = [\"[PAD]\",\"[EOS]\",\"[SEP]\",\"[BOS]\", \"[CLS]\", \"[MASK]\", \"[UNK]\"]\n",
    "    \n",
    "    if token_type ==' char':\n",
    "        model = BPE(unk_token=\"[UNK]\")\n",
    "        tokenizer = Tokenizer(model)\n",
    "        \n",
    "        \n",
    "\n",
    "        # 用兼容等价分解合并对utf编码进行等价组合，比如全角A转换为半角A\n",
    "        tokenizer.normalizer = tokenizers.normalizers.Sequence([NFKC()])\n",
    "\n",
    "        # 标点符号，数字，及Metaspace预分割（否则decode出来没有空格）\n",
    "        tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Sequence(\n",
    "            [Punctuation(), Digits(individual_digits=True), Metaspace()]\n",
    "        )\n",
    "\n",
    "        tokenizer.add_special_tokens(special_tokens)\n",
    "        tokenizer.decoder = decoders.Metaspace()\n",
    "    elif token_type ==' byte':\n",
    "        # byte BPE n不需要unk_token\n",
    "        model = BPE() \n",
    "        tokenizer = Tokenizer(model)\n",
    "        tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=True)\n",
    "\n",
    "        tokenizer.add_special_tokens(special_tokens)\n",
    "        tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False, use_regex=True)\n",
    "        tokenizer.post_processor = tokenizers.processors.ByteLevel(trim_offsets=False)\n",
    "    else:\n",
    "        raise Exception('token type must be `char` or `byte`')\n",
    "\n",
    "    trainer = BpeTrainer(vocab_size=40960, min_frequency=100, show_progress=True, special_tokens=special_tokens)\n",
    "    tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)\n",
    "\n",
    "    # add \\t \\n \n",
    "    if '\\t' not in tokenizer.get_vocab():\n",
    "        tokenizer.add_tokens(['\\t'])\n",
    "    if '\\n' not in tokenizer.get_vocab():\n",
    "        tokenizer.add_tokens(['\\n'])\n",
    "\n",
    "    tokenizer.save(tokenizer_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. 开始训练tokenizer\n",
    "1亿个字符至少需要`32G`内存（其实`32G`还是不太够，会频繁触发swap），CPU`13600k`训练时长大概1个小时。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_my_huggingface_wiki_tokenizer(token_type='byte')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. 将训练的tokenizer转换为PreTrainedTokenizerFast并保存\n",
    "转换是为了方便作为`AutoTokenizer`传到其他`huggingface`组件使用。\n",
    "\n",
    "转换时要手动指定`pad_token`、`eos_token`等特殊token，因为它不指定你原来的tokenizer中哪些字符是这些特殊字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slow_tokenizer = Tokenizer.from_file(tokenizer_save_path)\n",
    "tokenizer = PreTrainedTokenizerFast(\n",
    "    tokenizer_object=slow_tokenizer,\n",
    "    unk_token=\"[UNK]\",\n",
    "    pad_token=\"[PAD]\",\n",
    "    cls_token=\"[CLS]\",\n",
    "    sep_token=\"[SEP]\",\n",
    "    mask_token=\"[MASK]\",\n",
    "    bos_token='[BOS]',\n",
    "    eos_token='[EOS]',                  \n",
    ")\n",
    "tokenizer.save_pretrained('./model_save/fast_tokenizer/')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
